using OMEinsum

"""
    make_swap_perm(s1, s2, nD)

Build an `nD`-tuple permutation of `1:nD` that is the identity except that the
entries at positions `s1` and `s2` are exchanged. Because it is a single
transposition, the permutation is its own inverse, so applying it twice with
`permutedims` restores the original axis order.

Throws `BoundsError` if `s1` or `s2` lies outside `1:nD`.
"""
function make_swap_perm(s1, s2, nD)
    perm = Vector(1:nD)
    # RHS is evaluated first, then perm[s1] and perm[s2] are assigned left to right.
    perm[s1], perm[s2] = s2, s1
    return (perm...,)
end

"""
    multiply_ith_dimension(Pi, i, X)

Apply the matrix `Pi` along axis `i` of the array `X`. The result `Y` has the same
axes as `X`, except that axis `i` now has length `size(Pi, 1)`, and

    Y[…, a, …] = Σⱼ Pi[a, j] * X[…, j, …]      (a and j sitting at position i)

Axis `i` is swapped to the front, the array is flattened to a matrix, multiplied
by `Pi`, and then un-flattened and swapped back.
"""
function multiply_ith_dimension(Pi, i, X)
    # A single transposition is self-inverse, so the same permutation undoes itself.
    swap = make_swap_perm(1, i, ndims(X))

    front = permutedims(X, swap)            # axis i is now axis 1
    dims = size(front)
    flat = reshape(front, dims[1], :)       # (length of axis i, all other axes)

    product = reshape(Pi * flat, size(Pi, 1), Base.tail(dims)...)
    return permutedims(product, swap)
end

"""
    outer(pis)

Return the outer (tensor) product of the vectors in `pis` (any indexable collection
of vectors, e.g. a `Vector` or `Tuple`) as an `N`-dimensional array `R`, where
`N = length(pis)` and

    R[i₁, i₂, …, i_N] == pis[1][i₁] * pis[2][i₂] * … * pis[N][i_N]

A freshly allocated array is always returned, even when `pis` holds a single vector.

Throws `ArgumentError` if `pis` is empty.
"""
function outer(pis)
    isempty(pis) && throw(ArgumentError("outer requires at least one vector"))
    N = length(pis)
    # Lay the k-th vector along axis k (length 1 on every other axis). Broadcasting `*`
    # over all of them then forms every product directly. Products are taken left to
    # right (pis[1] first), and only Base is needed; `kron` would require the
    # LinearAlgebra stdlib, which this file does not import.
    aligned = (
        reshape(v, ntuple(d -> d == k ? length(v) : 1, N)) for (k, v) in enumerate(pis)
    )
    return broadcast(*, aligned...)
end

"""
    batch_multiply_ith_dimension(P, i, X)

Batched version of `multiply_ith_dimension`: contract axis `i` of `X` with axis
`1 + i` of `P`, using a separate matrix for every combination of the remaining
axes. Axis 1 of `P` becomes the new axis `i` of the result.

NOTE(review): this presumably expects `size(P) == (m, size(X)...)`, i.e. axis `k + 1`
of `P` lines up with axis `k` of `X` (TODO confirm with callers; nothing enforces it).
Under that convention the result `Y` satisfies

    Y[…, a, …] = Σⱼ P[a, …, j, …] * X[…, j, …]

where `a` and `j` sit at position `i` in `Y`/`X`, and `j` sits at position `1 + i` in `P`.
"""
function batch_multiply_ith_dimension(P, i, X)
    # Move P's contracted axis (1 + i) to slot 2 and X's contracted axis (i) to slot 1,
    # so both share the same trailing "batch" axes.
    Pswapped = permutedims(P, make_swap_perm(2, 1 + i, ndims(P)))
    Xswapped = permutedims(X, make_swap_perm(1, i, ndims(X)))

    Pdims = size(Pswapped)
    batchdims = Pdims[3:end]

    # Collapse the batch axes: P becomes (out, in, batch), X becomes (in, batch).
    P3 = reshape(Pswapped, Pdims[1], Pdims[2], :)
    X2 = reshape(Xswapped, size(Xswapped, 1), :)

    # Per-batch matrix-vector product: Y2[a, l] = Σⱼ P3[a, j, l] * X2[j, l].
    Y2 = ein"ijl, jl -> il"(P3, X2)
    Y = reshape(Y2, Pdims[1], batchdims...)

    return permutedims(Y, make_swap_perm(1, i, ndims(Y)))
end
